package com.ml4ai.junit;

import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.Arrays;
import java.util.List;

/**
 * Exploratory JUnit tests that print the results of basic ND4J {@link INDArray} operations:
 * point/all indexing with in-place modification and write-back, transpose, concatenation and
 * vertical stacking, reshape, and filling slices of a 3-D array via {@code put}.
 *
 * <p>These tests only print to standard output; they contain no assertions.
 *
 * <p>Created by leecheng on 2018/5/18.
 */
public class NDArray {

    /**
     * Builds a 5x4x3x2 array filled with 0..n-1, extracts the slice at [1, 0, 0, :], adds 1 to
     * it in place, then writes the (transposed) slice back into the original array.
     */
    @Test
    public void testNDArray() {
        List<Integer> shape = Arrays.asList(5, 4, 3, 2);
        // Unbox the shape into the int[] form that reshape(...) takes.
        int[] shapes = shape.stream().mapToInt(Integer::intValue).toArray();
        // Total element count; reducing with identity 1 avoids an unchecked Optional.get().
        int product = Arrays.stream(shapes).reduce(1, (a, b) -> a * b);
        double[] val = new double[product];
        for (int i = 0; i < product; i++) {
            val[i] = i;
        }
        INDArray nd = Nd4j.create(val);
        nd = nd.reshape(shapes);

        // "原始" = "original"
        System.out.println("原始：");
        System.out.println(nd);

        // Fix the first three dimensions and take every element of the last one.
        INDArrayIndex[] tensorLocation = {NDArrayIndex.point(1), NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.all()};
        INDArray subFeatures = nd.get(tensorLocation);
        // "截断" = "slice"
        System.out.println("截断");
        System.out.println(subFeatures);

        // NOTE(review): in ND4J, get() with point/all indices typically returns a view, so this
        // in-place add may already be reflected in nd — confirm for the ND4J version in use.
        subFeatures.addi(1);
        // "改变子集" = "modified subset"
        System.out.println("改变子集");
        System.out.println(subFeatures.transpose());

        // Write the slice back. The transposed orientation presumably relies on ND4J accepting
        // any vector of matching length here — TODO confirm.
        nd.put(tensorLocation, subFeatures.transpose());
        // "回填" = "written back"
        System.out.println("回填");
        System.out.println(nd);

    }

    /** Prints a 3x5 array created by {@code Nd4j.create(3, 5)} and its transpose. */
    @Test
    public void testTranspose() {
        INDArray z = Nd4j.create(3, 5);
        System.out.println(z);
        // println (rather than printf without "%n") so the output ends with a newline.
        System.out.println(z.transpose());
    }

    /** Prints two random 2x5 arrays joined along dimension 1 and stacked vertically. */
    @Test
    public void testStack() {
        INDArray a = Nd4j.rand(2, 5);
        INDArray cd = Nd4j.rand(2, 5);
        System.out.println(Nd4j.concat(1, a, cd));
        System.out.println(Nd4j.vstack(a, cd));
    }

    /**
     * Prints the results of reshaping a single-element array to shape [1], and a random 2x32
     * array to 2x2x4x4 and then to 1x2x32.
     */
    @Test
    public void testReshape() {
        INDArray x = Nd4j.create(new double[]{0.5});
        System.out.println(x);
        INDArray z = x.reshape(1);
        System.out.println(z);

        INDArray a = Nd4j.rand(new int[]{2, 32});
        INDArray b = a.reshape(2, 2, 4, 4);
        INDArray c = b.reshape(1, 2, 32);
        System.out.println(a);
        System.out.println(b);
        System.out.println(c);

    }

    /**
     * Fills a 2x1x2 array by writing the vectors {1, 2} and {3, 4} into its two slices along
     * dimension 0, then prints it.
     */
    @Test
    public void testConcat() {
        INDArray a = Nd4j.create(new double[]{1, 2});
        INDArray b = Nd4j.create(new double[]{3, 4});
        INDArray c = Nd4j.create(2, 1, 2);
        c.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()}, a);
        c.put(new INDArrayIndex[]{NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()}, b);
        System.out.println(c);
    }



}
